package org.zjvis.datascience.spark.util;

import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Joiner;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.*;
import scala.Tuple2;

import java.sql.*;
import java.util.*;

/**
 * Utility helpers for Spark jobs: SQL statement building, Spark schema
 * construction/conversion, Greenplum dataset read/write, and raw JDBC access
 * to MySQL/Greenplum.
 *
 * <p>NOTE(review): {@link #setSchema(String)} mutates the shared static
 * {@code gpInfoMap}, so two jobs in the same JVM reading different schemas
 * concurrently can race on the {@code dbschema} entry — confirm callers are
 * effectively single-threaded per JVM.
 *
 * @date 2021-12-23
 */
public class UtilTool {

    /** Kept for its side effect of loading algorithm.properties at class init. */
    private static Config config = Config.getInstance("algorithm.properties");

    private static final String JDBC_DRIVER_MYSQL = "com.mysql.jdbc.Driver";

    private static final String DB_URL_MYSQL = Config.DB_URL_MYSQL;

    private static final String USER_MYSQL = Config.USER_MYSQL;

    private static final String PASSWORD_MYSQL = Config.PASSWORD_MYSQL;

    private static final String JDBC_DRIVER_GP = "com.pivotal.jdbc.GreenplumDriver";

    private static final String DB_URL_GP = Config.DB_URL_GP;

    private static final String USER_GP = Config.USER_GP;

    private static final String PASSWORD_GP = Config.PASSWORD_GP;

    public static final String SELECT_SQL = "select %s from %s";

    public static final String DEFAULT_ID_COL = "_record_id_";

    public static final String MODEL_PATH = "hdfs:///zjvis/modelResult/%s_%s";

    public static final String RUN_META_PATH = "hdfs:///zjvis/runtimeMeta/%s_%s";

    /** Greenplum connector options; {@code dbschema} is rewritten by {@link #setSchema}. */
    public static Map<String, String> gpInfoMap;

    static {
        gpInfoMap = new HashMap<>();
        gpInfoMap.put("url", Config.DB_URL_GP);
        gpInfoMap.put("user", Config.USER_GP);
        gpInfoMap.put("password", Config.PASSWORD_GP);
        gpInfoMap.put("dbschema", "dataset");
    }

    /** Static utility class — not meant to be instantiated. */
    private UtilTool() {
    }

    /**
     * Builds a {@code select <cols> from <table>} statement.
     *
     * <p>NOTE(review): column and table identifiers are formatted directly into
     * the SQL (identifiers cannot be bound via JDBC parameters), so callers must
     * pass trusted names only.
     *
     * @param featureCols feature column names (must be non-null)
     * @param table       fully qualified table name
     * @param idCol       optional id column, prepended first when non-empty
     * @param otherCols   optional extra columns appended after the feature columns
     * @return the assembled SELECT statement
     */
    public static String buildSelectSql(String[] featureCols, String table, String idCol,
                                        String[] otherCols) {
        String[] cols;
        if (otherCols == null || otherCols.length == 0) {
            cols = featureCols;
        } else {
            cols = new String[featureCols.length + otherCols.length];
            System.arraycopy(featureCols, 0, cols, 0, featureCols.length);
            System.arraycopy(otherCols, 0, cols, featureCols.length, otherCols.length);
        }
        String colList = Joiner.on(",").join(cols);
        if (StringUtils.isNotEmpty(idCol)) {
            colList = idCol + "," + colList;
        }
        return String.format(SELECT_SQL, colList, table);
    }

    /**
     * Convenience overload of {@link #buildSelectSql(String[], String, String, String[])}
     * with no extra columns.
     */
    public static String buildSelectSql(String[] featureCols, String table, String idCol) {
        return buildSelectSql(featureCols, table, idCol, null);
    }

    /**
     * Stub kept for API compatibility.
     *
     * @return always {@code null} — TODO(review): implement or remove; callers
     *         must null-check until then
     */
    public static JSONObject buildOutputResult() {
        return null;
    }

    /**
     * Maps each field name of the schema to its Spark {@link DataType},
     * keeping decimal types as-is.
     */
    public static Map<String, DataType> getDataTypeMaps(StructType structType) {
        return getDataTypeMaps(structType, false);
    }

    /**
     * Maps each field name of the schema to its Spark {@link DataType}.
     *
     * @param structType        source schema; null or empty yields an empty map
     * @param isDecimalToDouble when true, {@link DecimalType} fields are reported
     *                          as {@link DataTypes#DoubleType}
     * @return mutable name-to-type map
     */
    public static Map<String, DataType> getDataTypeMaps(StructType structType,
                                                        boolean isDecimalToDouble) {
        Map<String, DataType> map = new HashMap<>();
        if (structType == null || structType.length() == 0) {
            return map;
        }
        for (StructField field : structType.fields()) {
            if (isDecimalToDouble && field.dataType() instanceof DecimalType) {
                map.put(field.name(), DataTypes.DoubleType);
            } else {
                map.put(field.name(), field.dataType());
            }
        }
        return map;
    }

    /** Creates a schema from {@code keys} using the types in {@code typeMap}. */
    public static StructType createSchema(Map<String, DataType> typeMap, List<String> keys) {
        return createSchema(typeMap, keys, 0);
    }

    /**
     * Creates a schema from {@code keys} plus {@code k} generated double columns
     * named {@code f1..fk}.
     */
    public static StructType createSchema(Map<String, DataType> typeMap, List<String> keys, int k) {
        return createSchema(typeMap, keys, k, null);
    }

    /**
     * Creates a Spark schema.
     *
     * @param typeMap  name-to-type lookup for {@code keys}; a key missing from the
     *                 map yields a field with a null type (unchanged legacy behavior)
     * @param keys     ordered column names to include
     * @param k        number of extra nullable double columns {@code f1..fk}; 0 for none
     * @param otherMap optional extra columns appended last (iteration order of the map)
     * @return the assembled {@link StructType}; every field is nullable
     */
    public static StructType createSchema(Map<String, DataType> typeMap, List<String> keys, int k,
                                          Map<String, DataType> otherMap) {
        List<StructField> structFields = new ArrayList<>();
        for (String key : keys) {
            structFields.add(DataTypes.createStructField(key, typeMap.get(key), true));
        }
        // Generated feature columns f1..fk, always doubles.
        for (int i = 1; i <= k; ++i) {
            structFields.add(DataTypes
                    .createStructField(String.format("f%s", i), DataTypes.DoubleType, true));
        }
        if (otherMap != null) {
            for (Map.Entry<String, DataType> entry : otherMap.entrySet()) {
                structFields
                        .add(DataTypes.createStructField(entry.getKey(), entry.getValue(), true));
            }
        }
        return DataTypes.createStructType(structFields);
    }


    /** @return the shared (mutable) Greenplum connector options */
    private static Map<String, String> gpInfo() {
        return gpInfoMap;
    }

    /**
     * Rewrites the Greenplum schema used by subsequent reads/writes.
     * Mutates shared static state — see the class-level note on concurrency.
     */
    public static void setSchema(String schema) {
        gpInfoMap.put("dbschema", schema);
    }

    /** Reads a whole Greenplum table; see the int overload for sampling. */
    public static Dataset<Row> readFromGreenPlum(SparkSession sparkSession, String table,
                                                 String partitionCol) {
        return readFromGreenPlum(sparkSession, table, partitionCol, -1);
    }

    /**
     * Reads a Greenplum table as a Spark dataset.
     *
     * @param sparkSession active session
     * @param table        {@code schema.table}; must contain exactly one dot
     * @param partitionCol partition column for parallel reads; falls back to
     *                     {@link #DEFAULT_ID_COL} when empty
     * @param sampleNumber when &gt; 0, only that many rows are kept, coalesced
     *                     into a single partition; otherwise the full table
     * @return the loaded dataset
     */
    public static Dataset<Row> readFromGreenPlum(SparkSession sparkSession, String table,
                                                 String partitionCol, int sampleNumber) {
        String partitionColumn =
                StringUtils.isEmpty(partitionCol) ? DEFAULT_ID_COL : partitionCol;
        String[] parts = table.split("\\.");
        setSchema(parts[0]);
        Dataset<Row> dataset = sparkSession.read().format("greenplum")
                .options(gpInfo()).option("dbtable", parts[1])
                .option("partitionColumn", partitionColumn)
                .load();
        if (sampleNumber > 0) {
            // coalesce(1) so the sampled rows land in a single partition.
            dataset = dataset.limit(sampleNumber).coalesce(1);
        }
        return dataset;
    }

    /**
     * Wraps column names as Spark {@link Column}s, with the optional id column first.
     */
    public static List<Column> selectColumns(String[] cols, String idCol) {
        List<Column> columnList = new ArrayList<>();
        if (StringUtils.isNotEmpty(idCol)) {
            columnList.add(new Column(idCol));
        }
        for (String col : cols) {
            columnList.add(new Column(col));
        }
        return columnList;
    }

    /**
     * Finds feature columns that are not already doubles and records their
     * original-to-double type transitions.
     *
     * <p>Side effect: {@code schema} is mutated in place — each selected column's
     * entry is overwritten with {@link DataTypes#DoubleType}.
     *
     * @return map of column name to (original type, DoubleType) pairs
     */
    public static Map<String, Tuple2<DataType, DataType>> getNeedChangeDoubleTypeKeys(
            String[] featureCols, Map<String, DataType> schema) {
        Map<String, Tuple2<DataType, DataType>> needChangeTypes = new HashMap<>();
        for (String colName : featureCols) {
            if (schema.get(colName) != DataTypes.DoubleType) {
                needChangeTypes
                        .put(colName, new Tuple2<>(schema.get(colName), DataTypes.DoubleType));
                schema.put(colName, DataTypes.DoubleType);
            }
        }
        return needChangeTypes;
    }

    /**
     * Casts each recorded column back to its original type (the first element
     * of the tuple), preserving the column name.
     *
     * <p>Note: the cast column is re-appended, so its position in the schema moves
     * to the end (unchanged legacy behavior).
     */
    public static Dataset<Row> changeBackDataTypeForKey(Dataset<Row> df,
                                                        Map<String, Tuple2<DataType, DataType>> needChangeTypes) {
        Dataset<Row> resultDF = df;
        for (Map.Entry<String, Tuple2<DataType, DataType>> entry : needChangeTypes.entrySet()) {
            String colName = entry.getKey();
            DataType rawType = entry.getValue()._1;
            String tmpName = String.format("%s_tmp", colName);
            resultDF = resultDF
                    .withColumn(tmpName, resultDF.col(colName).cast(rawType))
                    .drop(colName)
                    .withColumnRenamed(tmpName, colName);
        }
        return resultDF;
    }

    /**
     * Appends the dataset into the given Greenplum table ({@code schema.table}).
     *
     * @return always {@code true}; failures surface as runtime exceptions from Spark
     */
    public static boolean saveGreenplumTable(Dataset<Row> dataset, String tableName) {
        String[] parts = tableName.split("\\.");
        setSchema(parts[0]);

        dataset.write().format("greenplum")
                .options(gpInfo()).option("dbtable", parts[1])
                .mode(SaveMode.Append)
                .save();
        return true;
    }

    /**
     * Opens a JDBC connection to Greenplum ({@code isGp}) or MySQL.
     *
     * @return the connection, or {@code null} when the driver fails to load or
     *         the connect fails (legacy contract — callers null-check)
     */
    public static Connection getConn(boolean isGp) {
        String driverStr = isGp ? JDBC_DRIVER_GP : JDBC_DRIVER_MYSQL;
        String dbUrl = isGp ? DB_URL_GP : DB_URL_MYSQL;
        String user = isGp ? USER_GP : USER_MYSQL;
        String password = isGp ? PASSWORD_GP : PASSWORD_MYSQL;
        try {
            Class.forName(driverStr);
            return DriverManager.getConnection(dbUrl, user, password);
        } catch (Exception e) {
            // Preserve legacy behavior: report and return null instead of throwing.
            e.printStackTrace();
            return null;
        }
    }

    /** Opens a MySQL connection; see {@link #getConn(boolean)}. */
    public static Connection getConn() {
        return getConn(false);
    }

    /**
     * Fetches column name → SQL type name for a Greenplum table via a
     * zero-row query ({@code WHERE 1=0}).
     *
     * <p>NOTE(review): {@code tableName} is formatted into the SQL (identifiers
     * cannot be JDBC-parameterized) — callers must pass trusted names.
     *
     * @return the metadata map; empty when no connection or on query failure
     */
    public static Map<String, String> getTableMetaForGp(String tableName) {
        Map<String, String> meta = new HashMap<>();
        Connection conn = getConn(true);
        if (conn == null) {
            return meta;
        }
        String sql = String.format("SELECT * FROM %s WHERE 1=0", tableName);
        // try-with-resources closes ResultSet, Statement and Connection, which the
        // original leaked (only conn.close()) and whose close failure it swallowed.
        try (Connection c = conn;
             Statement stat = c.createStatement();
             ResultSet rs = stat.executeQuery(sql)) {
            ResultSetMetaData metaData = rs.getMetaData();
            int columnCount = metaData.getColumnCount();
            for (int i = 1; i <= columnCount; i++) {
                meta.put(metaData.getColumnName(i), metaData.getColumnTypeName(i));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return meta;
    }

    /**
     * Drops a Greenplum table if it exists.
     *
     * <p>NOTE(review): {@code tableName} is formatted into the SQL — trusted
     * identifiers only.
     *
     * @return 1 on success (or when no connection could be opened — legacy
     *         contract), -1 when the drop statement fails, -2 when closing the
     *         connection fails
     */
    public static int dropTable(String tableName) {
        Connection conn = getConn(true);
        if (conn == null) {
            return 1;
        }
        String sql = String.format("drop table if exists %s", tableName);
        try (Statement stat = conn.createStatement()) {
            // BUG FIX: DDL must not go through executeQuery() — drivers such as the
            // Greenplum/Postgres JDBC driver throw "no results were returned" on DDL.
            stat.execute(sql);
        } catch (Exception e) {
            e.printStackTrace();
            return -1;
        } finally {
            try {
                conn.close();
            } catch (SQLException e) {
                // Deliberate return-from-finally to preserve the original -2 contract.
                return -2;
            }
        }
        return 1;
    }
}
